from .config import CfgNode as CN

# Root node of the default configuration tree; every option in this file is
# attached to it as an attribute.
_C = CN()

# Model-level options.
_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"
# NOTE(review): presumably a name looked up in a model registry — verify against the builder.
_C.MODEL.META_ARCHITECTURE = "ACAIModel"
_C.MODEL.PIXEL_MEAN = [0.]  # or [0.5, 0.5, 0.5] if INPUT.FORMAT = "RGB" for example
_C.MODEL.PIXEL_STD = [1.]  # or [0.5, 0.5, 0.5] if INPUT.FORMAT = "RGB" for example
_C.MODEL.IGNORE_INDEX = -100  # ignore_index for cross entropy loss

# Input options.
_C.INPUT = CN()
_C.INPUT.FORMAT = "L"  # RGB or L (grayscale)
# NOTE(review): -1 presumably means "no limit on frames" — TODO confirm in the data loading code.
_C.INPUT.N_FRAMES_PER_VIDEO_TRAIN = -1
_C.INPUT.N_FRAMES_PER_VIDEO_TEST = -1  # take first n frames from test video
# NOTE(review): presumably rescales pixel values into [0, 1] — confirm where it is consumed.
_C.INPUT.SCALE_TO_ZEROONE = True
_C.INPUT.PREPARE_SLICES_TRAIN = False

# Top-level boolean flag, off by default.
# NOTE(review): presumably switches on adversarial training — verify against the trainer.
_C.GAN_MODE_ON = False

# Dataset selections for training and testing; both default to empty tuples
# and are expected to be filled in by an experiment config.
_C.DATASETS = CN()
_C.DATASETS.TRAIN = ()
_C.DATASETS.TEST = ()

# Data loader options.
_C.DATALOADER = CN()
_C.DATALOADER.NUM_WORKERS = 4
# Name of the sampler used for training data (resolved elsewhere in the project).
_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"

# Network weight-initialization scheme; one of: normal | xavier | kaiming | orthogonal.
_C.MODEL.INIT_TYPE = 'normal'
# Variance of the initialization distribution.
# NOTE(review): confirm whether the consumer treats this as a variance or a
# standard deviation / gain.
_C.MODEL.INIT_VARIANCE = 0.02

# Options for the autoregressive model. All numeric and tuple defaults below
# are placeholders (0 / empty) that an experiment config must override.
_C.MODEL.AUTOREGRESSIVE = CN()
_C.MODEL.AUTOREGRESSIVE.NAME = ""
# Sub-node "VT".
# NOTE(review): the abbreviated keys (NC, NV, D, DA, DE, N_PRIME, SHARE_P) are
# not explained anywhere in this file — document them next to the model that
# reads them.
_C.MODEL.AUTOREGRESSIVE.VT = CN()
_C.MODEL.AUTOREGRESSIVE.VT.NC = 0
_C.MODEL.AUTOREGRESSIVE.VT.NV = 0
_C.MODEL.AUTOREGRESSIVE.VT.KERNEL = ()
_C.MODEL.AUTOREGRESSIVE.VT.STRIDE = ()
_C.MODEL.AUTOREGRESSIVE.VT.D = 0
_C.MODEL.AUTOREGRESSIVE.VT.DA = 0
_C.MODEL.AUTOREGRESSIVE.VT.DE = 0
# NOTE(review): the _E / _D suffixes presumably denote encoder / decoder — TODO confirm.
_C.MODEL.AUTOREGRESSIVE.VT.BLOCKS_E = ()
_C.MODEL.AUTOREGRESSIVE.VT.N_HEAD_E = ()
_C.MODEL.AUTOREGRESSIVE.VT.BLOCKS_D = ()
_C.MODEL.AUTOREGRESSIVE.VT.N_HEAD_D = ()
_C.MODEL.AUTOREGRESSIVE.VT.N_PRIME = 0
_C.MODEL.AUTOREGRESSIVE.VT.PAD_VALUE = -1
_C.MODEL.AUTOREGRESSIVE.VT.SHARE_P = True
_C.MODEL.AUTOREGRESSIVE.VT.SHARE_EMBEDDINGS = False
_C.MODEL.AUTOREGRESSIVE.VT.CLASS_NUM = 0

# Encoder options. The keys mirror those of MODEL.GENERATOR.
_C.MODEL.ENCODER = CN()
_C.MODEL.ENCODER.WEIGHTS = ""  # empty by default; presumably a checkpoint path — TODO confirm
_C.MODEL.ENCODER.NAME = ""
_C.MODEL.ENCODER.IN_CHANNELS = 1
# NOTE(review): NF presumably is the base number of filters — verify in the network code.
_C.MODEL.ENCODER.NF = 16
_C.MODEL.ENCODER.RES_CHANNELS = 0
_C.MODEL.ENCODER.OUT_CHANNELS = 16
_C.MODEL.ENCODER.NORM = ""
_C.MODEL.ENCODER.N_LAYERS = 0
_C.MODEL.ENCODER.SPECTRAL = False
_C.MODEL.ENCODER.OUT_ACTIVATION = ""

# Generator options. The keys mirror those of MODEL.ENCODER.
_C.MODEL.GENERATOR = CN()
_C.MODEL.GENERATOR.WEIGHTS = ""  # empty by default; presumably a checkpoint path — TODO confirm
_C.MODEL.GENERATOR.NAME = ""
# Default matches MODEL.ENCODER.OUT_CHANNELS (16).
_C.MODEL.GENERATOR.IN_CHANNELS = 16
_C.MODEL.GENERATOR.NF = 16
_C.MODEL.GENERATOR.RES_CHANNELS = 0
# NOTE(review): defaults to 3 while INPUT.FORMAT defaults to "L" and
# ENCODER.IN_CHANNELS to 1 — confirm experiment configs override this consistently.
_C.MODEL.GENERATOR.OUT_CHANNELS = 3
_C.MODEL.GENERATOR.NORM = ""
_C.MODEL.GENERATOR.N_LAYERS = 0
_C.MODEL.GENERATOR.SPECTRAL = False
_C.MODEL.GENERATOR.OUT_ACTIVATION = ""

# Codebook options.
_C.MODEL.CODEBOOK = CN()
_C.MODEL.CODEBOOK.NUM = 1
_C.MODEL.CODEBOOK.SIZE = 512
# NOTE(review): DIM defaults to 256 while ENCODER.OUT_CHANNELS defaults to 16 —
# confirm how the two are expected to relate.
_C.MODEL.CODEBOOK.DIM = 256
_C.MODEL.CODEBOOK.WEIGHTS = ""
# NOTE(review): presumably selects exponential-moving-average codebook updates — TODO confirm.
_C.MODEL.CODEBOOK.EMA = False
_C.MODEL.CODEBOOK.BETA = 1.0

# Solver (training loop / optimization) options.
_C.SOLVER = CN()
_C.SOLVER.MAX_ITER = 40000
# NOTE(review): -1 presumably disables the supervised-iteration cap — TODO confirm.
_C.SOLVER.SUPERVISED_MAX_ITER = -1

# LR scheduler configuration.
# For now G and D share the same scheduler settings.
_C.SOLVER.LR_SCHEDULER_NAME = "Identity"
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = ()
# NOTE(review): -1 presumably means "no warmup" — verify in the scheduler builder.
_C.SOLVER.WARMUP_ITERS = -1
_C.SOLVER.WARMUP_FACTOR = 0.01
_C.SOLVER.WARMUP_METHOD = "linear"

# Optimizer configuration.
# Keys suffixed _G and _D hold separate values for the two networks referred
# to as "G" and "D" in the scheduler note (presumably generator and
# discriminator — TODO confirm).
_C.SOLVER.OPTIMIZER_NAME = "adam"
_C.SOLVER.LR_G = 0.0001
_C.SOLVER.LR_D = 0.0004

# Weight decay, split by parameter group (BASE / BIAS / NORM) and by network.
_C.SOLVER.WEIGHT_DECAY = CN()
_C.SOLVER.WEIGHT_DECAY.BASE_G = 0.0
_C.SOLVER.WEIGHT_DECAY.BIAS_G = 0.0
_C.SOLVER.WEIGHT_DECAY.NORM_G = 0.0
_C.SOLVER.WEIGHT_DECAY.BASE_D = 0.0
_C.SOLVER.WEIGHT_DECAY.BIAS_D = 0.0
_C.SOLVER.WEIGHT_DECAY.NORM_D = 0.0

# Adam hyper-parameters.
# NOTE(review): BETA2_G (0.9) differs from BETA2_D (0.999) — confirm this
# asymmetry is intentional and not a typo.
_C.SOLVER.ADAM = CN()
_C.SOLVER.ADAM.BETA1_G = 0.9
_C.SOLVER.ADAM.BETA2_G = 0.9
_C.SOLVER.ADAM.BETA1_D = 0.9
_C.SOLVER.ADAM.BETA2_D = 0.999

# RMSprop hyper-parameters.
_C.SOLVER.RMSPROP = CN()
_C.SOLVER.RMSPROP.ALPHA_G = 0.99
_C.SOLVER.RMSPROP.ALPHA_D = 0.99
_C.SOLVER.RMSPROP.MOMENTUM_G = 0.0
_C.SOLVER.RMSPROP.MOMENTUM_D = 0.0

# NOTE(review): presumably the number of gradient-accumulation steps per
# optimizer update — verify in the trainer.
_C.SOLVER.ACCUMULATION_STEPS = 1

# NOTE(review): the default checkpoint period (50000) exceeds the default
# MAX_ITER (40000) — confirm whether a final checkpoint is still written.
_C.SOLVER.CHECKPOINT_PERIOD = 50000
_C.SOLVER.IMS_PER_BATCH = 32  # batch size
_C.SOLVER.D_UPDATE_RATIO = 1
_C.SOLVER.D_INIT_ITERS = -1
_C.SOLVER.MAXUP = False

# Loss options. Each loss has its own sub-node with an ONN flag (False by
# default) and weighting factors.
# NOTE(review): ONN presumably toggles whether the loss term is enabled —
# verify against the loss builder.
_C.LOSS = CN()

# Pixel-wise loss.
_C.LOSS.PIXEL = CN()
_C.LOSS.PIXEL.ONN = False
_C.LOSS.PIXEL.LAMBDA = 1.0
_C.LOSS.PIXEL.MODE = "l2"  # l1 | l2

# GAN loss, with separate weights for G and D.
_C.LOSS.GAN = CN()
_C.LOSS.GAN.ONN = False
_C.LOSS.GAN.LAMBDA_G = 1.0
_C.LOSS.GAN.LAMBDA_D = 1.0
_C.LOSS.GAN.REAL_LABEL = 1.0
_C.LOSS.GAN.FAKE_LABEL = 0.0
_C.LOSS.GAN.MODE = "wgan"

# Test / evaluation options.
_C.TEST = CN()
_C.TEST.EXPECTED_RESULTS = []
# NOTE(review): 0 presumably disables periodic evaluation — TODO confirm.
_C.TEST.EVAL_PERIOD = 0
_C.TEST.N_SAMPLES = 0
_C.TEST.EVALUATORS = ""
# Options for the "VT" sampler, including a nested VQ_VAE node whose string
# fields all default to empty.
# NOTE(review): CFG and the *_WEIGHTS fields are presumably file paths —
# verify where they are consumed.
_C.TEST.VT_SAMPLER = CN()
_C.TEST.VT_SAMPLER.VQ_VAE = CN()
_C.TEST.VT_SAMPLER.VQ_VAE.CFG = ""
_C.TEST.VT_SAMPLER.VQ_VAE.ENCODER_WEIGHTS = ""
_C.TEST.VT_SAMPLER.VQ_VAE.GENERATOR_WEIGHTS = ""
_C.TEST.VT_SAMPLER.VQ_VAE.CODEBOOK_WEIGHTS = ""
_C.TEST.VT_SAMPLER.N_PRIME = 5
_C.TEST.VT_SAMPLER.NUM_SAMPLES = 10

# Miscellaneous top-level options.
_C.OUTPUT_DIR = "./output"
# NOTE(review): a NO_HTML option is stubbed out below but never defined —
# either implement it or delete the line.
# _C.NO_HTML
# NOTE(review): -1 presumably means "do not fix the random seed" — TODO confirm.
_C.SEED = -1
_C.CUDNN_BENCHMARK = True
# NOTE(review): the very large default presumably makes periodic visualization
# effectively never trigger — confirm in the trainer.
_C.VIS_PERIOD = 100000000000

_C.VERSION = 1

# NOTE(review): the purpose of GLOBAL.HACK is not evident from this file —
# document it or remove it if unused.
_C.GLOBAL = CN()
_C.GLOBAL.HACK = 1.0
